The MNIST digits dataset is a famous dataset of handwritten digit images. You can read more about it on Wikipedia or Yann LeCun's page. It's a useful dataset because it provides a simple, well-understood image-processing task for which we know exactly what state-of-the-art accuracy looks like.
I plan to use this dataset for a couple of upcoming machine learning blog posts, and since the first step of pretty much any ML task is 'explore your data,' I figured I would post this first so I can refer back to it instead of repeating it in each subsequent post.
Conveniently, scikit-learn has a built-in utility for loading this (and other) standard datasets.
In [22]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
In [23]:
# fetch_mldata was removed from scikit-learn (mldata.org is defunct);
# fetch_openml is the current loader for MNIST
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, as_frame=False, data_home='datasets/')
In [24]:
# Convert sklearn 'datasets bunch' object to Pandas DataFrames
y = pd.Series(mnist.target).astype('int').astype('category')
X = pd.DataFrame(mnist.data)
We can see below that our data (X) and target (y) have 70,000 rows, meaning we have information on 70,000 digit images. Our X, or independent variables dataset, has 784 columns, which correspond to the 784 pixel values in a 28-pixel x 28-pixel image (28x28 = 784). Our y, or target, is a single column representing the true digit labels (0-9) for each image.
In [18]:
X.shape, y.shape
Out[18]:
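As a quick sanity check on the 784 = 28x28 relationship, here's a minimal numpy sketch showing that a 784-length row round-trips through a 28x28 reshape unchanged:

```python
import numpy as np

# A 784-length vector reshapes to a 28x28 image and flattens back unchanged
row = np.arange(784)
img = row.reshape(28, 28)

print(img.shape)                         # (28, 28)
print(np.array_equal(img.ravel(), row))  # True
```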
In [46]:
# Change column-names in X to reflect that they are pixel values
num_pixels = X.shape[1]
X.columns = ['pixel_' + str(i) for i in range(num_pixels)]
# print first row of X
X.head(1)
Out[46]:
Below we see the min, max, mean, median, and most-common pixel-intensity values across all pixels in all images. As suggested by the first row above, our most common value is 0. In fact even the median is 0, which means over half of our pixels are background/blank space. Makes sense.
In [37]:
X_values = pd.Series(X.values.ravel())
print(" min: {},\n max: {},\n mean: {},\n median: {},\n most common value: {}".format(
    X_values.min(),
    X_values.max(),
    X_values.mean(),
    X_values.median(),
    X_values.value_counts().idxmax()))
We might wonder if there are only a few distinct pixel values present in the data (e.g. black, white, and a few shades of grey), but in fact all 256 values in the 0-255 range appear:
In [47]:
len(np.unique(X.values))
Out[47]:
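To see how those intensities are distributed, one option is a histogram. The sketch below uses synthetic stand-in values so it runs on its own; in the notebook you'd pass the real `X_values` from above instead, where the 0 bin dwarfs everything else, so a log y-axis helps:

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic stand-in for X_values; swap in the real pixel series from above
rng = np.random.default_rng(0)
demo_values = rng.integers(0, 256, size=10000)

plt.hist(demo_values, bins=256, range=(0, 256))
plt.yscale('log')  # the background (0) bin dominates on real MNIST data
plt.xlabel('pixel intensity')
plt.ylabel('count (log scale)')
plt.show()
```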
We can also take a look at the digit images themselves with matplotlib's handy pyplot.imshow function. imshow accepts an array to plot, which it interprets as pixel values, along with a color map that determines how each pixel value is displayed.
In the code below, we'll plot our first row/image, using the "reverse grayscale" color-map, to plot 0 (background in this dataset) as white.
In [39]:
# First row is first image
first_image = X.loc[0,:]
first_label = y[0]
# 784 columns correspond to 28x28 image
plottable_image = np.reshape(first_image.values, (28, 28))
# Plot the image
plt.imshow(plottable_image, cmap='gray_r')
plt.title('Digit Label: {}'.format(first_label))
plt.show()
And here are a few more...
In [44]:
images_to_plot = 9
random_indices = random.sample(range(X.shape[0]), images_to_plot)
sample_images = X.loc[random_indices, :]
sample_labels = y.loc[random_indices]
In [45]:
plt.clf()
plt.style.use('seaborn-muted')
fig, axes = plt.subplots(3, 3,
                         figsize=(5, 5),
                         sharex=True, sharey=True,
                         subplot_kw=dict(adjustable='box', aspect='equal'))  # https://stackoverflow.com/q/44703433/1870832
for i in range(images_to_plot):
    # axes (subplot) objects are stored in a 2d array, accessed with axes[row, col]
    subplot_row = i // 3
    subplot_col = i % 3
    ax = axes[subplot_row, subplot_col]

    # plot image on subplot
    plottable_image = np.reshape(sample_images.iloc[i, :].values, (28, 28))
    ax.imshow(plottable_image, cmap='gray_r')

    ax.set_title('Digit Label: {}'.format(sample_labels.iloc[i]))
    ax.set_xbound([0, 28])

plt.tight_layout()
plt.show()
One last thing I'd want to check before moving forward with any classification task is how balanced our dataset is. Do we have a pretty even distribution of each digit, or mostly 7s, for example?
In [6]:
y.value_counts(normalize=True)
Out[6]:
Great. Almost all the digits appear in between 9.7% and 10.4% of our rows. The most common digit is 1, at 11.3%, and the least common is 5, at 9.0%. Overall a very well-balanced dataset.
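If you'd rather eyeball this balance than read it off a table, a quick bar-chart sketch works too. Shown here with synthetic labels so it runs standalone; in the notebook, use the real y built above:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Synthetic stand-in labels; replace with the real y built earlier
rng = np.random.default_rng(0)
y_demo = pd.Series(rng.integers(0, 10, size=1000))

shares = y_demo.value_counts(normalize=True).sort_index()
ax = shares.plot(kind='bar')
ax.set_xlabel('digit')
ax.set_ylabel('share of dataset')
plt.tight_layout()
plt.show()
```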